{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference in Discrete Bayesian Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook walks through a simple example of exact inference in Bayesian networks using pgmpy. We use the Asia network (http://www.bnlearn.com/bnrepository/#asia) throughout."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the asia model from the bnlearn repository\n",
    "\n",
    "from pgmpy.utils import get_example_model\n",
    "\n",
    "asia_model = get_example_model(\"asia\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Nodes:  ['asia', 'tub', 'smoke', 'lung', 'bronc', 'either', 'xray', 'dysp']\n",
      "Edges:  [('asia', 'tub'), ('tub', 'either'), ('smoke', 'lung'), ('smoke', 'bronc'), ('lung', 'either'), ('bronc', 'dysp'), ('either', 'xray'), ('either', 'dysp')]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[<TabularCPD representing P(asia:2) at 0x7f08a40e6a90>,\n",
       " <TabularCPD representing P(bronc:2 | smoke:2) at 0x7f08a40e6dc0>,\n",
       " <TabularCPD representing P(dysp:2 | bronc:2, either:2) at 0x7f08a40fa730>,\n",
       " <TabularCPD representing P(either:2 | lung:2, tub:2) at 0x7f08a40fa100>,\n",
       " <TabularCPD representing P(lung:2 | smoke:2) at 0x7f08a40fa790>,\n",
       " <TabularCPD representing P(smoke:2) at 0x7f08a40fa5e0>,\n",
       " <TabularCPD representing P(tub:2 | asia:2) at 0x7f08a40fac40>,\n",
       " <TabularCPD representing P(xray:2 | either:2) at 0x7f08a40fab80>]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Nodes: \", asia_model.nodes())\n",
    "print(\"Edges: \", asia_model.edges())\n",
    "asia_model.get_cpds()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you would like to create a model from scratch, please refer to the Creating Bayesian Networks notebook: https://github.com/pgmpy/pgmpy/blob/dev/examples/Creating%20a%20Bayesian%20Network.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Initialize the inference class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Currently, pgmpy supports two exact inference algorithms: (1) Variable Elimination and (2) Belief Propagation. The following example uses `VariableElimination`, but `BeliefPropagation` has an identical API, so all the methods shown below also work for `BeliefPropagation`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initializing the VariableElimination class\n",
    "\n",
    "from pgmpy.inference import VariableElimination\n",
    "\n",
    "asia_infer = VariableElimination(asia_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Inference using hard evidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+--------------+\n",
      "| bronc      |   phi(bronc) |\n",
      "+============+==============+\n",
      "| bronc(yes) |       0.3000 |\n",
      "+------------+--------------+\n",
      "| bronc(no)  |       0.7000 |\n",
      "+------------+--------------+\n",
      "+------------+-----------+-------------------+\n",
      "| bronc      | asia      |   phi(bronc,asia) |\n",
      "+============+===========+===================+\n",
      "| bronc(yes) | asia(yes) |            0.0060 |\n",
      "+------------+-----------+-------------------+\n",
      "| bronc(yes) | asia(no)  |            0.5940 |\n",
      "+------------+-----------+-------------------+\n",
      "| bronc(no)  | asia(yes) |            0.0040 |\n",
      "+------------+-----------+-------------------+\n",
      "| bronc(no)  | asia(no)  |            0.3960 |\n",
      "+------------+-----------+-------------------+\n",
      "+------------+--------------+\n",
      "| bronc      |   phi(bronc) |\n",
      "+============+==============+\n",
      "| bronc(yes) |       0.3000 |\n",
      "+------------+--------------+\n",
      "| bronc(no)  |       0.7000 |\n",
      "+------------+--------------+\n",
      "+-----------+-------------+\n",
      "| asia      |   phi(asia) |\n",
      "+===========+=============+\n",
      "| asia(yes) |      0.0100 |\n",
      "+-----------+-------------+\n",
      "| asia(no)  |      0.9900 |\n",
      "+-----------+-------------+\n"
     ]
    }
   ],
   "source": [
    "# Computing the probability of bronc given smoke=no.\n",
    "q = asia_infer.query(variables=[\"bronc\"], evidence={\"smoke\": \"no\"})\n",
    "print(q)\n",
    "\n",
    "# Computing the joint probability of bronc and asia given smoke=yes\n",
    "q = asia_infer.query(variables=[\"bronc\", \"asia\"], evidence={\"smoke\": \"yes\"})\n",
    "print(q)\n",
    "\n",
    "# Computing the probabilities (not joint) of bronc and asia given smoke=no\n",
    "q = asia_infer.query(variables=[\"bronc\", \"asia\"], evidence={\"smoke\": \"no\"}, joint=False)\n",
    "for factor in q.values():\n",
    "    print(factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'bronc': 'no'}\n",
      "{'bronc': 'yes', 'asia': 'no'}\n"
     ]
    }
   ],
   "source": [
    "# Computing the MAP of bronc given smoke=no.\n",
    "q = asia_infer.map_query(variables=[\"bronc\"], evidence={\"smoke\": \"no\"})\n",
    "print(q)\n",
    "\n",
    "# Computing the MAP of bronc and asia given smoke=yes\n",
    "q = asia_infer.map_query(variables=[\"bronc\", \"asia\"], evidence={\"smoke\": \"yes\"})\n",
    "print(q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: Inference using virtual evidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+--------------+\n",
      "| bronc      |   phi(bronc) |\n",
      "+============+==============+\n",
      "| bronc(yes) |       0.3000 |\n",
      "+------------+--------------+\n",
      "| bronc(no)  |       0.7000 |\n",
      "+------------+--------------+\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.factors.discrete import TabularCPD\n",
    "\n",
    "# Reuse the model's state names so the virtual evidence matches asia_model\n",
    "lung_state_names = asia_model.get_cpds(\"lung\").state_names\n",
    "\n",
    "lung_virt_evidence = TabularCPD(\n",
    "    variable=\"lung\", variable_card=2, values=[[0.4], [0.6]], state_names=lung_state_names\n",
    ")\n",
    "\n",
    "# Query with hard evidence smoke = no and virtual evidence lung = [0.4, 0.6]\n",
    "q = asia_infer.query(\n",
    "    variables=[\"bronc\"], evidence={\"smoke\": \"no\"}, virtual_evidence=[lung_virt_evidence]\n",
    ")\n",
    "print(q)\n",
    "\n",
    "# Query with hard evidence smoke = no and virtual evidences lung = [0.4, 0.6] and bronc = [0.3, 0.7]\n",
    "bronc_virt_evidence = TabularCPD(\n",
    "    variable=\"bronc\", variable_card=2, values=[[0.3], [0.7]],\n",
    "    state_names=asia_model.get_cpds(\"bronc\").state_names,\n",
    ")\n",
    "# Assumption: the original cell does not name the queried variable; dysp is used for illustration\n",
    "q = asia_infer.query(\n",
    "    variables=[\"dysp\"],\n",
    "    evidence={\"smoke\": \"no\"},\n",
    "    virtual_evidence=[lung_virt_evidence, bronc_virt_evidence],\n",
    ")\n",
    "print(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+------------+-----------+\n",
      "| smoke     | smoke(yes) | smoke(no) |\n",
      "+-----------+------------+-----------+\n",
      "| lung(yes) | 0.1        | 0.01      |\n",
      "+-----------+------------+-----------+\n",
      "| lung(no)  | 0.9        | 0.99      |\n",
      "+-----------+------------+-----------+\n"
     ]
    }
   ],
   "source": [
    "print(asia_model.get_cpds(\"lung\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5: Troubleshooting for slow inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For large models, or models whose variables have many states, inference can be quite slow. Some ways to deal with this:\n",
    "\n",
    "1. Reduce the number of states of variables by combining states together.\n",
    "2. Try a different elimination-order heuristic by specifying the `elimination_order` argument. Possible options are: MinFill, MinNeighbors, MinWeight, WeightedMinFill.\n",
    "3. Try a custom elimination order: the implemented heuristics for computing the elimination order might not be efficient in every case. If you can think of a more efficient order, you can pass it as a list to the `elimination_order` argument.\n",
    "4. If inference is still too slow, try approximate inference with sampling algorithms."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
